# evaluate a smoothed classifier on a dataset
import argparse
import datetime
import os
from time import time

from architectures import get_architecture
from core import Smooth
from datasets import get_dataset, DATASETS, get_num_classes
import torch


# Command-line interface. Names, types and defaults are part of the script's
# public contract; only the help texts describe how each value is used below.
parser = argparse.ArgumentParser(description='Certify many examples')
parser.add_argument("dataset", choices=DATASETS, help="which dataset")
parser.add_argument("--base_classifier", type=str, default="models/checkpoint.pth.tar",
                    help="path to saved pytorch model of base classifier")

# --sigma is handed to the Smooth constructor; the three values after it
# control the per-example sweep of noise levels passed to certify().
parser.add_argument("--sigma", type=float, default=0.5,
                    help="noise level given to the Smooth constructor")
parser.add_argument("--sigma_max", type=float, default=1.0,
                    help="the noise-level sweep stops once this value is reached")
parser.add_argument("--sigma_min", type=float, default=0.05,
                    help="first (smallest) noise level of the sweep")
parser.add_argument("--sigma_interval", type=float, default=0.05,
                    help="step between consecutive noise levels of the sweep")

# Results are written to "result-<name>" in the checkpoint's directory;
# --outfile is accepted for compatibility but is not read by this script.
parser.add_argument("--outfile", type=str, default="out",
                    help="output file (unused: results are written to "
                         "'result-<name>' next to the base classifier)")
parser.add_argument("--name", type=str, default="sigma",
                    help="suffix of the result file name ('result-<name>')")

parser.add_argument("--batch", type=int, default=400, help="batch size")
parser.add_argument("--skip", type=int, default=5,
                    help="only certify examples whose index is a multiple of this")
parser.add_argument("--min", type=int, default=0, help="begin on this example")
parser.add_argument("--max", type=int, default=-1,
                    help="stop at this example index (-1: no limit)")
parser.add_argument("--split", choices=["train", "train-test", "test"], default="test",
                    help="which dataset split to use")
parser.add_argument("--N0", type=int, default=100,
                    help="sample count passed to certify() as N0")
parser.add_argument("--N", type=int, default=1000, help="number of samples to use")
parser.add_argument("--alpha", type=float, default=0.001, help="failure probability")

args = parser.parse_args()

def sigma_schedule(sigma_min, sigma_max, sigma_interval):
    """Return the noise levels to sweep, from sigma_min up to sigma_max.

    Each value is computed as ``sigma_min + k * sigma_interval`` rather than by
    repeated addition, so rounding error does not accumulate and the number of
    steps does not depend on floating-point drift. No value above sigma_max
    (beyond rounding) is produced.

    Args:
        sigma_min: first noise level of the sweep.
        sigma_max: largest noise level allowed in the sweep.
        sigma_interval: positive step between consecutive noise levels.

    Returns:
        A list of floats; empty when sigma_max < sigma_min.

    Raises:
        ValueError: if sigma_interval is not positive (the sweep would never end).
    """
    if sigma_interval <= 0:
        raise ValueError("sigma_interval must be positive")
    if sigma_max < sigma_min:
        return []
    # The small tolerance keeps an end point such as 0.95 / 0.05 (which
    # evaluates to 18.999...) from being dropped by truncation.
    steps = int((sigma_max - sigma_min) / sigma_interval + 1e-9)
    return [sigma_min + k * sigma_interval for k in range(steps + 1)]


def example_indices(num_examples, start, stop, skip):
    """Return the dataset indices to certify.

    An index is selected when it is at least ``start``, below ``stop`` and a
    multiple of ``skip``. The stop bound is applied before the multiple-of-skip
    filter, so it takes effect even when ``stop`` is not a multiple of ``skip``.

    Args:
        num_examples: size of the dataset.
        start: first index that may be selected.
        stop: exclusive upper bound; a negative value means "no limit".
        skip: keep only indices divisible by this value.

    Returns:
        A list of selected indices in increasing order.
    """
    end = num_examples if stop < 0 else min(stop, num_examples)
    return [i for i in range(max(start, 0), end) if i % skip == 0]


def main():
    """Certify the selected examples at every swept noise level.

    Reads the module-level ``args`` and writes one tab-separated row per
    (example, sigma) pair to ``result-<name>`` in the checkpoint's directory,
    with columns idx, label, predict, radius, correct, time, sigma.
    """
    # load the base classifier
    checkpoint = torch.load(args.base_classifier)
    base_classifier = get_architecture(checkpoint["arch"], args.dataset)
    # strict=False: keys that do not match the architecture are ignored
    base_classifier.load_state_dict(checkpoint['state_dict'], strict=False)
    base_classifier.eval()

    # create the smoothed classifier g
    smoothed_classifier = Smooth(base_classifier, get_num_classes(args.dataset), args.sigma)

    dataset = get_dataset(args.dataset, args.split)
    sigmas = sigma_schedule(args.sigma_min, args.sigma_max, args.sigma_interval)
    indices = example_indices(len(dataset), args.min, args.max, args.skip)

    # results live next to the checkpoint file
    result_path = os.path.join(os.path.dirname(args.base_classifier), "result-" + args.name)

    # the context manager closes the file even if certification raises
    with open(result_path, 'w') as f:
        print("idx\tlabel\tpredict\tradius\tcorrect\ttime\tsigma", file=f, flush=True)

        for i in indices:
            if i % 100 == 0:
                print("complete ", i)

            (x, label) = dataset[i]
            x = x.cuda()

            for sigma in sigmas:
                before_time = time()
                # certify the prediction of g around x
                prediction, radius = smoothed_classifier.certify(
                    x, args.N0, args.N, args.alpha, args.batch, sigma)
                after_time = time()
                correct = int(prediction == label)

                time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
                # flush per row so partial results survive an interrupted run
                print("{}\t{}\t{}\t{:.3}\t{}\t{}\t{:.3}".format(
                    i, label, prediction, radius, correct, time_elapsed, sigma), file=f, flush=True)


if __name__ == "__main__":
    main()